{
 "metadata": {
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.7-final"
  },
  "orig_nbformat": 2,
  "kernelspec": {
   "name": "python_defaultSpec_1599565370251",
   "display_name": "Python 3.7.9 64-bit ('d2l': conda)"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2,
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 33,
   "metadata": {},
   "outputs": [],
   "source": [
    "from d2l import torch as d2l\n",
    "import torch\n",
    "import numpy as np\n",
    "from torch.utils import data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 34,
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "output_type": "stream",
     "name": "stdout",
     "text": "features: tensor([1.1136e+00, 9.3910e-04]) \nlabel: tensor([[ 6.4325],\n        [-4.7369],\n        [ 2.5370]])\nlabels.shape: torch.Size([1000, 1])\n"
    }
   ],
   "source": [
    "true_w = torch.tensor([2, -3.4])\n",
    "true_b = 4.2\n",
    "\n",
    "# y=Xw+b+ϵ\n",
    "features, labels = d2l.synthetic_data(true_w, true_b, 1000)\n",
    "print('features:', features[0],'\\nlabel:', labels[0:3]) #labels 只有1列\n",
    "print('labels.shape:',labels.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 35,
   "metadata": {},
   "outputs": [],
   "source": [
    "batch_size = 10\n",
    "data_iter = d2l.load_array((features, labels), batch_size)\n",
    "#next(iter(data_iter))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 36,
   "metadata": {},
   "outputs": [],
   "source": [
    "from torch import nn\n",
    "\n",
    "# 输入2个特征，输出1个标签\n",
    "net = nn.Sequential(nn.Linear(2, 1))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 37,
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "output_type": "execute_result",
     "data": {
      "text/plain": "tensor([0.])"
     },
     "metadata": {},
     "execution_count": 37
    }
   ],
   "source": [
    "# 初始化net参数\n",
    "net[0].weight.data.normal_(0, 0.01)\n",
    "net[0].bias.data.fill_(0) "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 38,
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "output_type": "stream",
     "name": "stdout",
     "text": "epoch 1, loss 0.000235\nepoch 2, loss 0.000096\nepoch 3, loss 0.000096\n"
    }
   ],
   "source": [
    "lr = 0.03 #学习率\n",
    "loss = nn.MSELoss() #定义loss\n",
    "trainer = torch.optim.SGD(net.parameters(), lr)\n",
    "\n",
    "num_epochs = 3\n",
    "for epoch in range(num_epochs):\n",
    "    for X, y in data_iter:\n",
    "        l = loss(net(X) ,y)\n",
    "        trainer.zero_grad()\n",
    "        l.backward()\n",
    "        trainer.step()\n",
    "    l = loss(net(features), labels)\n",
    "    print(f'epoch {epoch + 1}, loss {l:f}')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 39,
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "output_type": "stream",
     "name": "stdout",
     "text": "error in estimating w: tensor([-0.0003,  0.0002])\nerror in estimating b: tensor([-5.9605e-05])\n"
    }
   ],
   "source": [
    "w = net[0].weight.data\n",
    "print('error in estimating w:', true_w - w.reshape(true_w.shape))\n",
    "b = net[0].bias.data\n",
    "print('error in estimating b:', true_b - b)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 42,
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "output_type": "stream",
     "name": "stdout",
     "text": "Sequential(\n  (0): Linear(in_features=2, out_features=1, bias=True)\n)\n"
    }
   ],
   "source": [
    "print(net)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 43,
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "output_type": "stream",
     "name": "stdout",
     "text": "Parameter containing:\ntensor([[ 2.0003, -3.4002]], requires_grad=True)\n"
    }
   ],
   "source": [
    "print(net[0].weight)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ]
}